import torch
import torch.nn as nn
from torch.autograd import Variable
EPS = 1e-20


class OptTrans(nn.Module):
    """Entropic optimal-transport (Sinkhorn) loss between generated and target features.

    ``G_net`` maps the input feature ``x`` towards the shape of the target ``y``.
    An optional ``critic`` network then embeds both tensors. For every batch
    element, the rows (channels) of the embedded features are treated as OT
    samples, and the Sinkhorn cost <P, C> between them is computed.

    Args:
        config: project config. ``config.DEV.OT_ONE_DIM_FORM`` ('conv' or 'fc')
            selects the 1-D critic. It is only read when ``skip_critic`` is False
            and the input is 1-D (``spatial_x <= 1``).
        ch_x: number of channels of ``x``.
        spatial_x: spatial size of ``x``. Values > 1 select the 2-D (conv2d) path.
        ch_y, spatial_y: channels / spatial size of ``y``; -1 means "same as x".
        epsilon: entropic regularisation strength. Its reciprocal is stored in
            ``self.epsilon``, so the Gibbs kernel is ``exp(-self.epsilon * C)``.
        L: number of Sinkhorn iterations.
        remove_bias: if False, ``forward`` returns the debiased combination
            ``2*W(G(x), y) - W(G(x), G(x)) - W(y, y)``. If True, it returns
            ``W(G(x), y)`` only.
        C_form: ground cost; 'cosine' (1 - cosine similarity) or 'l2'.
        no_bp_P_L: if True, the transport plan is detached, so gradients flow
            only through the cost matrix.
        skip_critic: if True, no critic is built and features are compared
            directly.
    """

    def __init__(self, config, ch_x, spatial_x=-1, ch_y=-1, spatial_y=-1,
                 epsilon=1., L=5, remove_bias=False, C_form='cosine', no_bp_P_L=True, skip_critic=False):

        super(OptTrans, self).__init__()
        self.config = config
        # NOTE: stores 1/epsilon (the kernel uses exp(-self.epsilon * C)).
        self.epsilon = 1. / epsilon
        self.L = L
        self.remove_bias = remove_bias
        self.no_bp_P_L = no_bp_P_L
        self.C_form = C_form
        self.skip_critic = skip_critic
        self.two_dim = spatial_x > 1

        ch_y = ch_x if ch_y == -1 else ch_y
        spatial_y = spatial_x if spatial_y == -1 else spatial_y

        # define G_net
        if self.two_dim:
            if spatial_x != spatial_y:
                stride, out_pad = 2, 1   # upsample
            else:
                stride, out_pad = 1, 0   # keep spatial size unchanged
            self.G_net = nn.Sequential(
                nn.ConvTranspose2d(ch_x, ch_y, kernel_size=3, padding=1, stride=stride, output_padding=out_pad),
                nn.BatchNorm2d(ch_y),
                nn.ReLU(),
            )
        else:
            self.G_net = nn.Sequential(
                nn.Conv1d(ch_x, ch_y, kernel_size=3, padding=1, stride=1),
                # nn.BatchNorm1d(ch_y),
                nn.ReLU(),
            )

        # define critic (left undefined for an unknown 1-D form; reported at forward time)
        if not self.skip_critic:
            if self.two_dim:
                self.critic = nn.Sequential(
                    nn.Conv2d(ch_y, ch_y // 2, kernel_size=3, padding=1, stride=2),
                    nn.BatchNorm2d(ch_y // 2),
                    nn.ReLU(),
                    nn.Conv2d(ch_y // 2, ch_y // 4, kernel_size=3, padding=1, stride=2),
                    nn.BatchNorm2d(ch_y // 4),
                    nn.ReLU(),
                )
            elif self.config.DEV.OT_ONE_DIM_FORM == 'conv':
                self.critic = nn.Sequential(
                    nn.Conv1d(ch_y, ch_y // 4, kernel_size=3, padding=1, stride=1),
                    # BN omitted: a 1 (single sample) x 1024 x 1 batch makes BatchNorm1d error
                    # nn.BatchNorm1d(ch_y // 4),
                    nn.ReLU(),
                )
            elif self.config.DEV.OT_ONE_DIM_FORM == 'fc':
                self.critic = nn.Linear(ch_y, ch_y // 8)

    def forward(self, x, y):
        """Return the per-sample OT loss between ``G_net(x)`` and ``y``.

        x_upsample is generated from the latent variable x (or z); y is the ground truth.
            One-dim case:
                x shape (small feature):  say 15 x 1024 x 1
                y shape (big feature): same as 15 x 1024 x 1; it should be detached already.

        Returns:
            Tensor of shape (batch,), one loss value per batch element.
        """
        x_upsample = self.G_net(x)
        if self.remove_bias:
            return self._basic_compute_loss(x_upsample, y)
        return 2 * self._basic_compute_loss(x_upsample, y) \
            - self._basic_compute_loss(x_upsample, x_upsample) \
            - self._basic_compute_loss(y, y)

    def _apply_critic(self, feat):
        """Embed ``feat`` with the critic; identity when ``skip_critic`` is set."""
        if self.skip_critic:
            return feat
        critic = getattr(self, 'critic', None)
        if critic is None:
            raise ValueError(
                "no critic was built: config.DEV.OT_ONE_DIM_FORM must be 'conv' or 'fc'")
        if isinstance(critic, nn.Linear):
            # nn.Linear acts on the last dim: flatten (bs, ch, 1) -> (bs, ch) first.
            feat = feat.reshape(feat.size(0), -1)
        return critic(feat)

    def _basic_compute_loss(self, x, y):
        """Sinkhorn cost between ``x`` and ``y`` for each batch element; shape (bs,)."""
        bs = x.size(0)
        x_emb = self._apply_critic(x)
        y_emb = self._apply_critic(y)
        # (bs, channel_num, flattened spatial dims): each channel row is one OT sample
        x_all = x_emb.reshape(bs, x_emb.size(1), -1)
        y_all = y_emb.reshape(bs, y_emb.size(1), -1)
        return torch.stack([self._sinkhorn_iterate(x_i, y_i) for x_i, y_i in zip(x_all, y_all)])

    def _sinkhorn_iterate(self, x, y):
        """Return <P, C> for the rows of ``x`` (n, d) and ``y`` (m, d).

        C is the ground-cost matrix selected by ``C_form``. P = diag(a) K diag(b)
        is the Sinkhorn plan for uniform marginals after ``self.L`` iterations,
        with K = exp(-self.epsilon * C).

        Raises:
            ValueError: if ``self.C_form`` is neither 'l2' nor 'cosine'.
        """
        n, m = x.size(0), y.size(0)
        if self.C_form == 'l2':
            # (n, d, 1) - (1, d, m) -> (n, d, m); norm over d gives C[i, j] = ||x_i - y_j||
            C = torch.norm(x.unsqueeze(2) - y.t().unsqueeze(0), p=2, dim=1)
        elif self.C_form == 'cosine':
            # Out-of-place on purpose: x / y are views of the critic output, which
            # autograd may have saved (e.g. by ReLU); in-place edits break backward.
            x = x / (torch.norm(x, p=2, dim=1, keepdim=True) + EPS)
            y = y / (torch.norm(y, p=2, dim=1, keepdim=True) + EPS)
            C = 1 - torch.mm(x, y.t())
            # (Note from capsule project) C is slightly negative for some i, j
        else:
            raise ValueError(
                "unknown C_form {!r}; expected 'l2' or 'cosine'".format(self.C_form))

        K = torch.exp(-self.epsilon * C)
        # Uniform marginals, created on K's device/dtype (no hard-coded .cuda()).
        mu = torch.full((n, 1), 1. / n, dtype=K.dtype, device=K.device)
        nu = torch.full((m, 1), 1. / m, dtype=K.dtype, device=K.device)
        a, b = mu, nu  # defined even when self.L == 0
        for _ in range(self.L):
            a = mu / (torch.mm(K, b) + EPS)
            b = nu / (torch.mm(K.t(), a) + EPS)

        P = a * K * b.t()
        if self.no_bp_P_L:
            P = P.detach()
        # Frobenius inner product of plan and cost.
        return torch.sum(P * C)
